import numpy as np
import random
from sceneprogllm import LLM

class Raytracer2D:
    """Axis-aligned free-space queries between the children of a group.

    Each child's AABB (``obj.get_aabb()``, indexed as ``[min_xyz, max_xyz]``)
    is padded by a small margin once, at construction time. Distances are
    measured on the x/z plane along one of the rays 'x+', 'x-', 'z+', 'z-';
    the y extent only takes part in the overlap test. When the group defines
    both ``WIDTH`` and ``DEPTH``, the walls at x=0, x=WIDTH, z=0 and z=DEPTH
    cap the reported distances; otherwise ``MAX`` stands in for "unbounded".
    """

    def __init__(self, group):
        self.objects = group.children
        self.aabb = {obj: self.enhance_aabb(obj.get_aabb()) for obj in self.objects}
        # Stand-in for an unbounded distance (and for room size without walls).
        self.MAX = 1000

        # Walls are only enforced when the group defines BOTH dimensions.
        self.has_walls = hasattr(group, 'WIDTH') and hasattr(group, 'DEPTH')
        self.WIDTH = getattr(group, 'WIDTH', self.MAX)
        self.DEPTH = getattr(group, 'DEPTH', self.MAX)

    def enhance_aabb(self, aabb, margin=0.05):
        """Return ``aabb`` grown by ``margin`` on every side, as a (2, 3) array."""
        return np.array([[aabb[0][0] - margin, aabb[0][1] - margin, aabb[0][2] - margin],
                         [aabb[1][0] + margin, aabb[1][1] + margin, aabb[1][2] + margin]])

    def overlap(self, obj1, obj2):
        """Overlap test between the padded AABBs of two children (see ``aabb_overlap``)."""
        return self.aabb_overlap(self.aabb[obj1], self.aabb[obj2])

    def aabb_overlap(self, aabb1, aabb2):
        """Test two (2, 3) AABB arrays for a strictly positive-volume intersection.

        Returns:
            ``(False, 0.0)`` when the boxes are separated or merely touch on
            any axis; otherwise ``(True, degree)`` where ``degree`` is the
            intersection volume divided by the mean volume of the two boxes
            (0.0 if that mean is not positive).
        """
        # Touching faces (<=) do not count as overlap.
        if (
            aabb1[1][0] <= aabb2[0][0] or aabb2[1][0] <= aabb1[0][0] or  # x
            aabb1[1][1] <= aabb2[0][1] or aabb2[1][1] <= aabb1[0][1] or  # y
            aabb1[1][2] <= aabb2[0][2] or aabb2[1][2] <= aabb1[0][2]     # z
        ):
            return False, 0.0

        overlap_x = max(0.0, min(aabb1[1][0], aabb2[1][0]) - max(aabb1[0][0], aabb2[0][0]))
        overlap_y = max(0.0, min(aabb1[1][1], aabb2[1][1]) - max(aabb1[0][1], aabb2[0][1]))
        overlap_z = max(0.0, min(aabb1[1][2], aabb2[1][2]) - max(aabb1[0][2], aabb2[0][2]))
        overlap_volume = overlap_x * overlap_y * overlap_z

        vol1 = np.prod(aabb1[1] - aabb1[0])
        vol2 = np.prod(aabb2[1] - aabb2[0])
        avg_volume = (vol1 + vol2) / 2.0
        degree = overlap_volume / avg_volume if avg_volume > 0 else 0.0
        return True, degree

    def _directional_distance(self, obj1, obj2, axis, sign):
        """Free distance from ``obj1`` along the ray ``sign * axis`` w.r.t. ``obj2``.

        ``axis`` is 0 (x) or 2 (z); ``sign`` is +1 or -1. Negative rays are
        handled by mirroring coordinates (negating and swapping min/max), so a
        single code path serves all four directions.

        Returns, in order of precedence:
            * 0.0 if walls exist and ``obj1`` already crosses the wall ahead;
            * if the padded AABBs overlap: 0.0 when ``obj2``'s location lies
              ahead of ``obj1``'s along the ray, else the cap;
            * the cap if the boxes are disjoint on the perpendicular floor
              axis, or ``obj1``'s leading face is beyond ``obj2``'s far face;
            * 0.0 if ``obj2``'s near face is behind ``obj1``'s leading face;
            * otherwise the gap between those two faces.
        The cap is the distance to the wall ahead (walls) or ``MAX``.
        """
        box1, box2 = self.aabb[obj1], self.aabb[obj2]
        if sign > 0:
            front1 = box1[1][axis]                          # obj1's leading face
            near2, back2 = box2[0][axis], box2[1][axis]     # obj2's near / far faces
            wall = self.WIDTH if axis == 0 else self.DEPTH
        else:
            front1 = -box1[0][axis]
            near2, back2 = -box2[1][axis], -box2[0][axis]
            wall = 0.0                                      # mirrored lower wall

        if self.has_walls:
            if front1 > wall:
                return 0.0
            cap = wall - front1
        else:
            cap = self.MAX

        overlapping, _ = self.aabb_overlap(box1, box2)
        if overlapping:
            # Only the sign of the centre offset along the ray matters here.
            offset = obj2.get_location() - obj1.get_location()
            return 0.0 if sign * offset[axis] > 0 else cap

        perp = 2 - axis  # the other floor axis: x <-> z
        if box1[0][perp] > box2[1][perp] or box2[0][perp] > box1[1][perp]:
            return cap   # obj2 is off to the side; the ray misses it
        if front1 > back2:
            return cap   # obj2 lies entirely behind obj1's leading face
        if near2 < front1:
            return 0.0   # obj2 straddles obj1's leading face
        return near2 - front1

    def dist_in_xpos(self, obj1, obj2):
        """Free distance from ``obj1`` toward +x with respect to ``obj2``."""
        return self._directional_distance(obj1, obj2, axis=0, sign=1)

    def dist_in_xneg(self, obj1, obj2):
        """Free distance from ``obj1`` toward -x with respect to ``obj2``."""
        return self._directional_distance(obj1, obj2, axis=0, sign=-1)

    def dist_in_zpos(self, obj1, obj2):
        """Free distance from ``obj1`` toward +z with respect to ``obj2``."""
        return self._directional_distance(obj1, obj2, axis=2, sign=1)

    def dist_in_zneg(self, obj1, obj2):
        """Free distance from ``obj1`` toward -z with respect to ``obj2``.

        The overlap tie-break uses the z component of the centre offset
        (previously it mistakenly looked at x).
        """
        return self._directional_distance(obj1, obj2, axis=2, sign=-1)

    def compute_free_space(self, obj, dir):
        """Return ``(dist, nearest_obj)`` for the ray ``dir`` from ``obj``.

        ``dir`` must be one of 'x+', 'x-', 'z+', 'z-'. ``dist`` is the minimum
        per-object distance over all other children (starting from ``MAX``),
        and ``nearest_obj`` is the first child attaining it, or None if no
        child reports less than ``MAX``.

        NOTE(review): with walls, a child that is not in the ray's path still
        reports the wall distance, so ``nearest_obj`` can be an object that is
        not actually blocking; with no other children ``dist`` stays ``MAX``
        even if a wall is closer.

        Raises:
            ValueError: if ``dir`` is not a recognised ray.
        """
        if dir not in ('x+', 'x-', 'z+', 'z-'):
            raise ValueError("Invalid direction. Use 'x+', 'x-', 'z+', or 'z-'.")
        probe = {
            'x+': self.dist_in_xpos,
            'x-': self.dist_in_xneg,
            'z+': self.dist_in_zpos,
            'z-': self.dist_in_zneg,
        }[dir]

        dist, nearest_obj = self.MAX, None
        for other in self.objects:
            if other is obj:
                continue
            candidate = probe(obj, other)
            if candidate < dist:
                dist, nearest_obj = candidate, other
        return dist, nearest_obj

    def compute_free_space_all(self, obj):
        """Return ``{'dx+', 'dx-', 'dz+', 'dz-'}`` free distances for ``obj``."""
        return {f'd{ray}': self.compute_free_space(obj, ray)[0]
                for ray in ('x+', 'x-', 'z+', 'z-')}


class ConstraintBase:
    """Common base for layout constraints attached to a group.

    Subclasses must set ``self.type`` before calling ``super().__init__``.
    The constraint then registers itself, at most once, in the group's
    ``grad_constraints`` list ('GRADIENT') or ``vlm_constraints`` list ('VLM').

    The ``is_aligned_*`` helpers classify ``obj.get_rotation()`` (degrees;
    presumably a yaw about the vertical axis, TODO confirm) into one of four
    headings: 0 -> z+, 90 -> x+, 180 -> z-, 270 -> x-. Each heading owns the
    half-open sector [heading - 45, heading + 45), wrapped modulo 360, so every
    rotation matches exactly one heading.
    """

    # Half-width, in degrees, of the sector assigned to each heading.
    _ALIGN_TOLERANCE = 45

    def __init__(self, name, group):
        self.name = name
        self.group = group

        if self.type == 'GRADIENT':
            registry = self.group.grad_constraints
        elif self.type == 'VLM':
            registry = self.group.vlm_constraints
        else:
            raise ValueError("Constraint type must be 'GRADIENT' or 'VLM'")
        if self not in registry:
            registry.append(self)

    def _is_aligned(self, obj, heading):
        """Return True if ``obj``'s rotation falls in the sector around ``heading``."""
        # Wrap the offset into [-180, 180). Then 350 is 10 degrees from 0, and
        # 180 is NOT treated as aligned with 0 (the old `% 180` did both wrong).
        offset = (obj.get_rotation() - heading + 180) % 360 - 180
        return bool(-self._ALIGN_TOLERANCE <= offset < self._ALIGN_TOLERANCE)

    def is_aligned_zpos(self, obj):
        """True if ``obj`` faces +z (rotation near 0 degrees)."""
        return self._is_aligned(obj, 0)

    def is_aligned_xpos(self, obj):
        """True if ``obj`` faces +x (rotation near 90 degrees)."""
        return self._is_aligned(obj, 90)

    def is_aligned_zneg(self, obj):
        """True if ``obj`` faces -z (rotation near 180 degrees)."""
        return self._is_aligned(obj, 180)

    def is_aligned_xneg(self, obj):
        """True if ``obj`` faces -x (rotation near 270 degrees)."""
        return self._is_aligned(obj, 270)

class OverlapConstraint(ConstraintBase):
    """Gradient constraint that separates every pair of overlapping children.

    For each unordered pair whose padded AABBs intersect (per
    ``Raytracer2D.overlap``), each object's ``grad`` receives the vector from
    the other object's location to its own, scaled by the overlap degree and
    by ``self.weight``.
    """

    def __init__(self, group):
        self.name, self.type, self.weight = 'OverlapConstraint', 'GRADIENT', 1.0
        super().__init__(self.name, group)

    def compute_gradients(self):
        """Accumulate pairwise separation terms into ``obj.grad``."""
        tracer = Raytracer2D(self.group)
        members = self.group.children
        for idx, first in enumerate(members):
            for second in members[idx + 1:]:
                touching, amount = tracer.overlap(first, second)
                if not touching:
                    continue
                loc_a = first.get_location()
                loc_b = second.get_location()
                first.grad += (loc_a - loc_b) * amount * self.weight
                second.grad += (loc_b - loc_a) * amount * self.weight

class ClearanceConstraint(ConstraintBase):
    """Keep at least ``distance`` of free floor space around ``obj``.

    ``dir`` selects which sides are checked relative to the object's facing,
    as classified by the ``is_aligned_*`` helpers:

    * 'front' - the facing side only;
    * 'sides' - the two sides perpendicular to the facing (+ray first);
    * 'all'   - front, both sides, then back.

    For each checked side with less than ``distance`` of free space, ``obj``
    and the nearest child reported by ``Raytracer2D.compute_free_space`` each
    receive half of the shortfall along that ray (``obj`` backwards, the other
    forwards), scaled by ``weight``. Children listed in ``omit_objs`` are not
    considered as blockers.
    """

    # Unit vector of each ray understood by Raytracer2D.
    _RAY_AXES = {
        'x+': np.array([1.0, 0.0, 0.0]),
        'x-': np.array([-1.0, 0.0, 0.0]),
        'z+': np.array([0.0, 0.0, 1.0]),
        'z-': np.array([0.0, 0.0, -1.0]),
    }
    _OPPOSITE = {'x+': 'x-', 'x-': 'x+', 'z+': 'z-', 'z-': 'z+'}

    def __init__(self, group, obj, distance=0.5, dir='front', omit_objs=None):
        self.name = 'ClearanceConstraint'
        self.type = 'GRADIENT'
        self.weight = 1.0
        self.obj = obj
        self.distance = distance
        # Fresh list per instance; the previous `[]` default was shared.
        self.omit_objs = [] if omit_objs is None else list(omit_objs)
        assert dir in ['front', 'sides', 'all'], "Type must be 'front', 'sides', or 'all'"
        self.dir = dir
        super().__init__(self.name, group)

    def _facing_ray(self):
        """Return the ray ``obj`` faces ('z+', 'z-', 'x+', 'x-'), or None if unaligned."""
        if self.is_aligned_zpos(self.obj):
            return 'z+'
        if self.is_aligned_zneg(self.obj):
            return 'z-'
        if self.is_aligned_xpos(self.obj):
            return 'x+'
        if self.is_aligned_xneg(self.obj):
            return 'x-'
        return None

    def _enforce(self, raytracer, ray):
        """Push ``obj`` and its nearest blocker along ``ray`` apart if too close."""
        dist, nearest_obj = raytracer.compute_free_space(self.obj, ray)
        # NOTE(review): with walls, nearest_obj may not actually block the ray
        # (see Raytracer2D.compute_free_space); it is pushed anyway.
        if nearest_obj is None or not dist < self.distance:
            return
        step = self._RAY_AXES[ray] * ((self.distance - dist) / 2) * self.weight
        self.obj.grad -= step
        nearest_obj.grad += step

    def compute_gradients(self):
        """Accumulate clearance gradients for the sides selected by ``dir``."""
        facing = self._facing_ray()
        if facing is None:
            return

        raytracer = Raytracer2D(self.group)
        if self.omit_objs:
            # Hide omitted children from the ray queries (identity-based).
            omitted = {id(o) for o in self.omit_objs}
            raytracer.objects = [o for o in raytracer.objects if id(o) not in omitted]

        rays = []
        if self.dir in ('front', 'all'):
            rays.append(facing)
        if self.dir in ('sides', 'all'):
            rays.extend(('x+', 'x-') if facing[0] == 'z' else ('z+', 'z-'))
        if self.dir == 'all':
            rays.append(self._OPPOSITE[facing])

        for ray in rays:
            self._enforce(raytracer, ray)
                    
class AccessConstraint(ConstraintBase):
    """Keep ``target`` placed where it can be reached from ``obj``.

    ``dir='front'``: the target is moved in front of ``obj``'s facing side,
    laterally within ``obj``'s extent, with a gap in ``[min_dist, max_dist]``.
    ``dir='sides'``: the target is placed beside ``obj`` on the axis
    perpendicular to its facing, with a gap in ``[min_dist, max_dist]``.

    All bounds come from the margin-padded AABBs of ``Raytracer2D``. Every
    contribution is a float32 unit vector times an amount times ``weight``,
    added to (or subtracted from) ``obj.grad`` / ``target.grad``.
    """

    # Unit vectors used for gradient contributions (float32, as before).
    _PX = np.array([1, 0, 0], dtype=np.float32)
    _NX = np.array([-1, 0, 0], dtype=np.float32)
    _PZ = np.array([0, 0, 1], dtype=np.float32)
    _NZ = np.array([0, 0, -1], dtype=np.float32)

    def __init__(self, group, obj, target, min_dist=0.1, max_dist=0.15, dir='front'):
        self.name = 'AccessConstraint'
        self.type = 'GRADIENT'
        self.weight = 1.0
        self.obj = obj
        self.other = target
        self.min_dist = min_dist
        self.max_dist = max_dist
        assert dir in ['front', 'sides'], "Type must be 'front' or 'sides'"
        self.dir = dir
        super().__init__(self.name, group)

    def _keep_gap(self, rel, dist):
        """Drive the gap ``dist`` (measured along ``rel``, obj -> target) into range.

        Too close: subtract ``rel * shortfall * weight`` from obj.grad and add it
        to target.grad. Too far: the reverse, with the excess.
        """
        if dist < self.min_dist:
            push = rel * (self.min_dist - dist) * self.weight
            self.obj.grad -= push
            self.other.grad += push
        elif dist > self.max_dist:
            pull = rel * (dist - self.max_dist) * self.weight
            self.obj.grad += pull
            self.other.grad -= pull

    def _shift(self, rel, amount, move_obj=True):
        """Add ``rel * amount * weight`` to target.grad (and subtract it from obj.grad)."""
        step = rel * amount * self.weight
        self.other.grad += step
        if move_obj:
            self.obj.grad -= step

    def _half_widths(self):
        return self.obj.get_width() / 2 + self.other.get_width() / 2

    def _half_depths(self):
        return self.obj.get_depth() / 2 + self.other.get_depth() / 2

    def compute_gradients(self):
        """Accumulate access gradients for ``obj`` / ``target`` according to ``dir``."""
        raytracer = Raytracer2D(self.group)
        x1min, _, z1min = raytracer.aabb[self.obj][0]
        x1max, _, z1max = raytracer.aabb[self.obj][1]
        x2min, _, z2min = raytracer.aabb[self.other][0]
        x2max, _, z2max = raytracer.aabb[self.other][1]
        cx1, cz1 = (x1min + x1max) / 2, (z1min + z1max) / 2
        cx2, cz2 = (x2min + x2max) / 2, (z2min + z2max) / 2

        if self.dir == 'sides':
            if self.is_aligned_zpos(self.obj) or self.is_aligned_zneg(self.obj):
                # Target within obj's z-extent: keep the x-gap in range.
                if z1min < z2min and z1max > z2max:
                    if x1max <= x2min:
                        self._keep_gap(self._PX, x2min - x1max)
                    elif x1min >= x2max:
                        self._keep_gap(self._NX, x1min - x2max)

                if x1max < x2min or x1min > x2max:
                    # Separated in x: nudge target's z-extent toward obj's z-centre.
                    if z1min >= z2min:
                        self._shift(self._PZ, cz1 - z2max)
                    elif z1max < z2max:
                        self._shift(self._NZ, z2min - cz1)
                elif x2max > x1min or x2min < x1max:
                    # Overlapping in x: push them apart along x.
                    half_sum = (x1max - x1min) / 2 + (x2max - x2min) / 2
                    if cx1 <= cx2:
                        self._shift(self._PX, half_sum)
                    elif cx1 > cx2:
                        self._shift(self._NX, half_sum)

            # Previously called the nonexistent `is_aligned_x` (AttributeError).
            elif self.is_aligned_xpos(self.obj) or self.is_aligned_xneg(self.obj):
                # Target within obj's x-extent: keep the z-gap in range.
                if x1min < x2min and x1max > x2max:
                    if z1max <= z2min:
                        self._keep_gap(self._PZ, z2min - z1max)
                    elif z1min >= z2max:
                        # Target is at lower z, so obj -> target points -z
                        # (previously +z, which pulled them together when too close).
                        self._keep_gap(self._NZ, z1min - z2max)

                if z1max < z2min or z1min > z2max:
                    # Separated in z: nudge target's x-extent toward obj's x-centre.
                    if x1min >= x2min:
                        self._shift(self._PX, cx1 - x2max)
                    elif x1max < x2max:
                        self._shift(self._NX, x2min - cx1)
                elif z2max > z1min or z2min < z1max:
                    # Overlapping in z: push them apart along z.
                    half_sum = (z1max - z1min) / 2 + (z2max - z2min) / 2
                    if cz1 <= cz2:
                        self._shift(self._PZ, half_sum)
                    elif cz1 > cz2:
                        self._shift(self._NZ, half_sum)

        elif self.dir == 'front':
            if self.is_aligned_zpos(self.obj):
                if z2min <= z1max:
                    # Target not ahead (+z) of obj's front face: move only the
                    # target, sideways out of obj's x-extent or forward.
                    if cx2 < cx1 and x2max > x1min:
                        self._shift(self._NX, self._half_widths(), move_obj=False)
                    elif cx2 > cx1 and x2min < x1max:
                        self._shift(self._PX, self._half_widths(), move_obj=False)
                    else:
                        self._shift(self._PZ, self._half_depths(), move_obj=False)
                else:
                    # Target ahead: centre it laterally, then fix the gap.
                    if x1min >= cx2:
                        self._shift(self._PX, self._half_widths())  # was NameError (`other`)
                    elif x1max <= cx2:
                        self._shift(self._NX, self._half_widths())
                    else:
                        self._keep_gap(self._PZ, z2min - z1max)

            elif self.is_aligned_xpos(self.obj):
                if x2min <= x1max:
                    # Target not ahead (+x): move it sideways in z or forward.
                    if cz2 < cz1 and z2max > z1min:
                        self._shift(self._NZ, self._half_depths(), move_obj=False)
                    elif cz2 > cz1 and z2min < z1max:
                        self._shift(self._PZ, self._half_depths(), move_obj=False)
                    else:
                        self._shift(self._PX, self._half_widths(), move_obj=False)
                else:
                    if z1min >= cz2:
                        self._shift(self._PZ, self._half_depths())
                    elif z1max <= cz2:
                        self._shift(self._NZ, self._half_depths())
                    else:
                        self._keep_gap(self._PX, x2min - x1max)

            elif self.is_aligned_zneg(self.obj):
                # "Not ahead" in -z means z2max >= z1min. The old test was
                # inverted, and its lateral conditions could never be true.
                if z2max >= z1min:
                    if cx2 < cx1 and x2max > x1min:
                        self._shift(self._NX, self._half_widths(), move_obj=False)
                    elif cx2 > cx1 and x2min < x1max:
                        self._shift(self._PX, self._half_widths(), move_obj=False)
                    else:
                        self._shift(self._NZ, self._half_depths(), move_obj=False)
                else:
                    if x1min >= cx2:
                        self._shift(self._PX, self._half_widths())
                    elif x1max <= cx2:
                        self._shift(self._NX, self._half_widths())
                    else:
                        self._keep_gap(self._NZ, z1min - z2max)

            elif self.is_aligned_xneg(self.obj):
                if x2max >= x1min:
                    # Target not ahead (-x): move it sideways in z or forward.
                    if cz2 < cz1 and z2max > z1min:
                        self._shift(self._NZ, self._half_depths(), move_obj=False)
                    elif cz2 > cz1 and z2min < z1max:
                        self._shift(self._PZ, self._half_depths(), move_obj=False)
                    else:
                        self._shift(self._NX, self._half_widths(), move_obj=False)
                else:
                    if z1min >= cz2:
                        self._shift(self._PZ, self._half_depths())
                    elif z1max <= cz2:
                        self._shift(self._NZ, self._half_depths())
                    else:
                        self._keep_gap(self._NX, x1min - x2max)

class OutOfBoundsConstraint(ConstraintBase):
    """Gradient constraint that pushes children back inside the room footprint.

    The footprint is the x-z rectangle ``[0, WIDTH] x [0, DEPTH]``, read from
    the group once at construction time. An object whose margin-enhanced AABB
    (as produced by ``Raytracer2D``) crosses a side receives a gradient that
    points back inside, with magnitude equal to the overshoot plus a fixed
    buffer of 0.1, scaled by ``self.weight``.
    """

    def __init__(self, group):
        self.name = 'OutOfBoundsConstraint'
        self.type = 'GRADIENT'
        self.weight = 1.0
        assert hasattr(group, 'WIDTH') and hasattr(group, 'DEPTH'), "Group must have WIDTH and DEPTH attributes"

        # Bounds are snapshotted here; later changes to group.WIDTH/DEPTH are not seen.
        self.WIDTH, self.DEPTH = group.WIDTH, group.DEPTH
        super().__init__(self.name, group)

    def compute_gradients(self):
        """Accumulate an inward push into ``obj.grad`` for every out-of-bounds child."""
        BUFFER = 0.1
        boxes = Raytracer2D(self.group).aabb

        def correction(lo, hi, limit):
            # Signed push along one axis, or None when [lo, hi] lies within [0, limit].
            # The low side is checked first, so an object wider than the room is only pushed up.
            if lo < 0:
                return -lo + BUFFER
            if hi > limit:
                return -(hi - limit + BUFFER)
            return None

        for obj in self.group.children:
            (xmin, _, zmin), (xmax, _, zmax) = boxes[obj]

            push_x = correction(xmin, xmax, self.WIDTH)
            if push_x is not None:
                obj.grad += np.array([push_x, 0, 0], dtype=np.float32) * self.weight

            push_z = correction(zmin, zmax, self.DEPTH)
            if push_z is not None:
                obj.grad += np.array([0, 0, push_z], dtype=np.float32) * self.weight

class VisibilityConstraint(ConstraintBase):
    """Gradient constraint that keeps the line of sight from ``source`` to ``target`` clear.

    A quadrilateral "view corridor" is built in the x-z plane between the
    facing sides of the two objects' AABBs. Every other child of the group
    whose margin-enhanced footprint overlaps the corridor gets a fixed-size
    push (magnitude ``10 * weight``) perpendicular to the source->target line.
    """

    def __init__(self, group, source, target):
        self.name = 'VisibilityConstraint'
        self.type = 'GRADIENT'
        self.weight = 1.0
        self.source = source
        self.target = target

        super().__init__(self.name, group)

    def is_point_in_trapezoid(self, point, trapezoid):
        """Return True if 2-D ``point`` lies strictly inside the convex quad ``trapezoid``.

        The quad's four (x, z) vertices must be ordered counter-clockwise;
        points lying exactly on an edge count as outside.
        """
        def is_left(p0, p1, p2):
            """Check if p2 is strictly left of the directed line p0 -> p1."""
            return (p1[0] - p0[0]) * (p2[1] - p0[1]) - (p2[0] - p0[0]) * (p1[1] - p0[1]) > 0

        return all(is_left(trapezoid[i], trapezoid[(i + 1) % 4], point) for i in range(4))

    def build_trapezoid(self, source, target):
        """Build the (4, 2) float32 x-z corridor between the facing sides of source and target.

        Only near-axis-aligned configurations are supported: the x-z direction
        from source to target must satisfy cos > 0.9 against one of +z, +x,
        -z, -x. Vertices are emitted counter-clockwise, with the two source
        vertices on the source side facing the target.

        Raises:
            ValueError: if the direction is not close to a cardinal axis.
        """
        aabb1 = source.get_aabb()
        aabb2 = target.get_aabb()
        x1min, _, z1min = aabb1[0]
        x1max, _, z1max = aabb1[1]
        x2min, _, z2min = aabb2[0]
        x2max, _, z2max = aabb2[1]

        # Project onto the floor plane *before* normalising: a height difference
        # (e.g. a wall-mounted target) must not shrink the x-z direction below
        # the 0.9 thresholds.
        rel = (target.get_location() - source.get_location())[[0, 2]]
        rel = rel / (np.linalg.norm(rel) + 1e-3)

        if rel[1] > 0.9:        # target towards +z
            corners = [[x1min, z1max], [x1max, z1max], [x2max, z2min], [x2min, z2min]]
        elif rel[0] > 0.9:      # target towards +x
            corners = [[x1max, z1min], [x2min, z2min], [x2min, z2max], [x1max, z1max]]
        elif -rel[1] > 0.9:     # target towards -z: both source vertices on its z-min face
            corners = [[x1min, z1min], [x2min, z2max], [x2max, z2max], [x1max, z1min]]
        elif -rel[0] > 0.9:     # target towards -x
            corners = [[x1min, z1min], [x1min, z1max], [x2max, z2max], [x2max, z2min]]
        else:
            raise ValueError("Invalid direction for trapezoid construction")

        return np.array(corners, dtype=np.float32)

    @staticmethod
    def _convex_polygons_overlap(poly_a, poly_b):
        """Return True when two convex 2-D polygons share interior area (separating-axis test).

        Touching boundaries do not count as overlap.
        """
        tested_axis = False
        for poly in (poly_a, poly_b):
            n = len(poly)
            for i in range(n):
                edge = poly[(i + 1) % n] - poly[i]
                axis = np.array([-edge[1], edge[0]])
                if not axis.any():
                    continue  # zero-length edge gives no separating axis
                tested_axis = True
                proj_a = poly_a @ axis
                proj_b = poly_b @ axis
                if proj_a.max() <= proj_b.min() or proj_b.max() <= proj_a.min():
                    return False
        return tested_axis

    def is_obj_in_trapezoid(self, obj_aabb, trapezoid):
        """Return True if the x-z footprint of ``obj_aabb`` overlaps the corridor.

        ``obj_aabb`` is indexed as ``[min/max, axis]`` with axis 0 = x, 2 = z.
        """
        box = np.array([
            [obj_aabb[0, 0], obj_aabb[0, 2]],
            [obj_aabb[1, 0], obj_aabb[0, 2]],
            [obj_aabb[1, 0], obj_aabb[1, 2]],
            [obj_aabb[0, 0], obj_aabb[1, 2]]
        ], dtype=np.float32)
        if any(self.is_point_in_trapezoid(corner, trapezoid) for corner in box):
            return True
        # An object wider than the corridor can block it without any of its
        # corners lying inside; fall back to a full polygon-overlap test.
        return self._convex_polygons_overlap(box, np.asarray(trapezoid, dtype=np.float32))

    def projection_vector(self, v1, v2, v3):
        """Return the 2-D x-z vector from the line v1->v2 to point v3 (perpendicular offset).

        Raises:
            ValueError: if v1 and v2 coincide in the x-z plane.
        """
        v1 = v1[[0, 2]]  # Use only x and z coordinates
        v2 = v2[[0, 2]]
        v3 = v3[[0, 2]]
        d = v2 - v1           # Direction vector of the line
        a = v3 - v1           # Vector from v1 to v3
        d_norm_sq = np.dot(d, d)

        if d_norm_sq == 0:
            raise ValueError("v1 and v2 cannot be the same point")

        proj_scalar = np.dot(a, d) / d_norm_sq
        v_proj = v1 + proj_scalar * d  # Projection of v3 onto the line
        vt = v3 - v_proj               # Desired vector from projection to v3
        return vt

    def compute_gradients(self):
        """Push every child that blocks the source->target view sideways, out of the corridor."""
        PUSH = 10.0
        raytracer = Raytracer2D(self.group)
        trapezoid = self.build_trapezoid(self.source, self.target)
        v1 = self.source.get_location()
        v2 = self.target.get_location()
        for obj in self.group.children:
            if obj is self.source or obj is self.target:
                continue
            if not self.is_obj_in_trapezoid(raytracer.aabb[obj], trapezoid):
                continue

            offset = self.projection_vector(v1, v2, obj.get_location())
            vt = np.array([offset[0], 0, offset[1]], dtype=np.float32)  # lift back to 3-D (y = 0)
            norm = np.linalg.norm(vt)
            if norm > 1e-3:
                direction = vt / norm
            else:
                # Object centred on the sight line: push it to a random side,
                # perpendicular to the line in the x-z plane. The line has
                # non-zero length here, otherwise projection_vector raised.
                line = (v2 - v1)[[0, 2]]
                line = line / np.linalg.norm(line)
                side = random.choice([1.0, -1.0])
                direction = side * np.array([-line[1], 0, line[0]], dtype=np.float32)

            obj.grad += PUSH * direction * self.weight

class RenderingConstraint(ConstraintBase):
    """Render-in-the-loop gradient constraint that moves wall paintings toward a target layout.

    Each step renders ``wall`` with ``paintings``, detects painting centroids in
    the render, matches them one-to-one to centroids detected in a target image,
    and adds the resolution-normalised image-space offset to each painting's
    ``grad``.
    """

    # Both images are resized to this (width, height) before detection; pixel
    # offsets are divided by it to normalise the pseudo-gradient.
    RESOLUTION = (1920, 1080)

    def __init__(self, group, wall, paintings, target_image_path):
        from painting_detector import PaintingDetector
        self.name = 'RenderingConstraint'
        self.description = f"""
Helps in optimizing placement of paintings on the wall to match a given image target.
Inputs:
- wall: The wall name (str)
- paintings: List of painting objects (list)
- target_image_path: Path to the target image (str)
"""
        self.examples = f"""
with scene.RoomGroup() as room:
    ...
    painting = scene.AddAsset("A Beautiful Landscape")
    paintings = 3*paintings
    room.place_on_wall_back_left(paintings[0])
    room.place_on_wall_back_center(paintings[1])
    room.place_on_wall_back_right(paintings[2])

    room.RenderingConstraint("back_wall", paintings, "path/to/target/image.jpg")
"""

        self.type = 'GRADIENT'
        self.painting_detector = PaintingDetector()
        self.target_image_path = target_image_path

        self.target_centroids, self.target_bbox = self.painting_detector(self.target_image_path, resize=self.RESOLUTION)
        self.wall = wall
        self.paintings = paintings

        super().__init__(self.name, group)

    def compute_gradients(self):
        """Add a normalised image-space pseudo-gradient to each painting's ``grad``."""
        width, height = self.RESOLUTION

        ## Render the wall with paintings
        current_image_path = self.group.render_wall(self.wall, self.paintings)
        ## Detect centroids of each painting using Owlv2
        centroids, _ = self.painting_detector(current_image_path, resize=self.RESOLUTION)

        ## Use the Hungarian method to derive an optimal 1-1 mapping between centroids.
        perm = self.painting_detector.compute_mapping(centroids, self.target_centroids)
        mapped_centroids = [self.target_centroids[i] for i in perm]

        # NOTE(review): assumes centroids[i] belongs to self.paintings[i]; detection
        # order is not guaranteed to follow the painting list — confirm.
        # zip stops at the shortest input, so a missed detection skips a painting
        # instead of raising IndexError.
        for painting, current, target in zip(self.paintings, centroids, mapped_centroids):
            # Pseudo gradient: pixel offset from the current to the matched target centroid.
            offset = np.asarray(target, dtype=np.float32) - np.asarray(current, dtype=np.float32)
            # NOTE(review): image x/y are written straight into world x/y; image rows grow
            # downward and the wall's in-plane axes may not be world x/y — confirm signs/axes.
            painting.grad += np.array([offset[0] / width, offset[1] / height, 0], dtype=np.float32)
            
class ObjectProportionsConstraint(ConstraintBase):
    """VLM-feedback constraint that asks a vision-language model to judge object sizes.

    ``compute_gradients`` renders the group, lists the object descriptions, and
    returns the model's raw textual reply (the system prompt asks for
    ``rescale <object> by <factor>`` lines or 'no rescale').
    """

    def __init__(self, group):
        self.name = 'ObjectProportionsConstraint'
        self.type = 'VLM'

        system_prompt = (
            "\n"
            "You are given an image which shows front, right, back, left renders of a scene which consists of several objects. \n"
            "You are required to reason about the relative proportions of the objects in the scene and to ensure that they make sense in the context of the scene.\n"
            "For example, if the coffee table placed next to the sofa is too big, then you should respond with a rescale instruction for the coffee table. Like 'rescale coffee table by 0.5'. \n"
            "The rescale factor should be a float number between 0.1 and 0.9. Note that in most cases, rescaling isn't necessary, so only respond with a rescale instruction if you think it is necessary.\n"
            "An example of a rescale instruction is:\n"
            "\n"
            "```\n"
            "- rescale coffee table by 0.5\n"
            "- rescale sofa by 0.8\n"
            "```\n"
            "Note that not all objects were asked to be rescaled, only the ones that are too big or too small in the scene.\n"
            "In case you think that the objects are already in appropriate proportions, then just say 'no rescale'.\n"
            "You should not respond with any other instructions, only rescale instructions.\n"
        )
        self.VLM = LLM(name=self.name, system_desc=system_prompt)
        super().__init__(self.name, group)

    def compute_gradients(self):
        """Render the scene, query the VLM about object proportions, and return its reply."""
        image_path = self.group.render()
        object_list = ",".join(self.group.get_descriptions())

        query = (
            "\n"
            f"The scene has the following objects: {object_list}\n"
            "Now reason about the relative proportions of the objects in the scene and to ensure that they make sense in the context of the scene.\n"
        )
        return self.VLM(query, image_paths=[image_path])
    
class RoomProportionsConstraint(ConstraintBase):
    """VLM-feedback constraint that asks a vision-language model to judge the room size.

    ``compute_gradients`` renders the group, reports its occupancy ratio and
    object descriptions, and returns the model's raw textual reply (the system
    prompt asks for a ``rescale room by <factor>`` line or 'no rescale').
    """

    def __init__(self, group):
        self.name = 'RoomProportionsConstraint'
        self.type = 'VLM'

        system_prompt = (
            "\n"
            "You are given an image which shows front, right, back, left renders of a scene which consists of several objects. \n"
            "You are required to reason about the relative proportion (size) of the room in the scene and to ensure that it makes sense in the context of the scene. \n"
            "The room shouldn't feel too spacious or too cramped. You should respond with a rescale instruction so that the room feels more appropriate.\n"
            "The rescale factor should be a float number between 0.5 to 2.0. \n"
            "Additionally, you should also consider the occupancy ratio of the room (provided to you) when making your assessment. Ideally aim for an occupancy ratio of around 0.4. So if the occupancy ratio is too low, you should rescale the room to be smaller, and if it is too high, you should rescale the room to be larger.\n"
            "For example, if the room is too spacious (occupancy ratio like 0.2), then a rescale instruction like 'rescale room by 0.8' can be given based on how spacious the room currently is. \n"
            "An example of a rescale instruction is:\n"
            "```\n"
            "- rescale room by 0.8\n"
            "```\n"
            "Note that you should only respond with a rescale instruction if you think it is necessary. Otherwise, just leave the room as it is by saying 'no rescale'.\n"
        )
        self.VLM = LLM(name=self.name, system_desc=system_prompt)
        super().__init__(self.name, group)

    def compute_gradients(self):
        """Render the scene, query the VLM about the room's size, and return its reply."""
        image_path = self.group.render()
        ratio = self.group.compute_occupancy_ratio()
        object_list = ",".join(self.group.get_descriptions())

        query = (
            "\n"
            f"The scene has the following objects: {object_list}\n"
            f"The occupancy ratio of the room is {ratio:.2f}.\n"
            "Now reason about the relative proportion (size) of the room in the scene and to ensure that it makes sense in the context of the scene.\n"
        )
        return self.VLM(query, image_paths=[image_path])
         
# Registry of constraint classes; order is preserved for any caller that indexes it.
CONSTRAINTS = [
    OverlapConstraint, ClearanceConstraint, AccessConstraint,
    OutOfBoundsConstraint, VisibilityConstraint,
    ObjectProportionsConstraint, RoomProportionsConstraint,
    RenderingConstraint,
]

class VLMSolver:
    """Run every VLM constraint of a group and collect its textual feedback.

    Each constraint in ``group.vlm_constraints`` returns a reply from its
    ``compute_gradients`` call; every reply is appended, preceded by a newline,
    to ``group.scene.vlm_feedback``.
    """

    def __init__(self, group):
        self.group = group
        self.objects = group.children
        self.constraints = group.vlm_constraints
        self.scene = group.scene

    def __call__(self):
        # Lazily query one constraint at a time so feedback is appended as soon as it arrives.
        replies = (vlm_constraint.compute_gradients() for vlm_constraint in self.constraints)
        for reply in replies:
            self.scene.vlm_feedback += "\n" + reply

class GradSolver:
    """Stochastic, gradient-guided local search over object placements in the x-z plane.

    Each iteration resets every object's ``grad``, lets every constraint in
    ``group.grad_constraints`` accumulate into it, scores the four axis moves
    (+x, +z, -x, -z) of every object, samples two of them, and translates the
    chosen objects by ``|grad component| * lr`` along the chosen direction.
    """

    # Keys of Raytracer2D.compute_free_space_all, in action-column order (+x, +z, -x, -z).
    _DIRECTION_KEYS = ('dx+', 'dz+', 'dx-', 'dz-')
    # For each action column: (index of the grad component, unit move direction).
    _MOVES = (
        (0, np.array([1, 0, 0], dtype=np.float32)),
        (2, np.array([0, 0, 1], dtype=np.float32)),
        (0, np.array([-1, 0, 0], dtype=np.float32)),
        (2, np.array([0, 0, -1], dtype=np.float32)),
    )

    def __init__(self, group):
        self.group = group
        self.objects = group.children
        self.lr = 0.1

    def init_gradients(self):
        """Reset every object's ``grad`` to a float32 zero 3-vector."""
        for obj in self.objects:
            obj.grad = np.zeros(3, dtype=np.float32)

    def compute_gradients(self):
        """Run all gradient constraints and return the mean per-object gradient norm.

        Gradients with norm below 1e-3 are snapped to exactly zero. Returns 0.0
        when the group has no objects.
        """
        for constraint in self.group.grad_constraints:
            constraint.compute_gradients()

        if not self.objects:
            return 0.0

        total = 0.0
        for obj in self.objects:
            norm = np.linalg.norm(obj.grad)
            if norm < 1e-3:
                obj.grad = np.zeros(3, dtype=np.float32)
                norm = 0.0
            total += norm
        return total / len(self.objects)

    def free_space_affinity(self, space):
        """Map a free distance to (-inf, 1): 0 at no space, saturating towards 1."""
        return 1 - np.exp(-2 * space)

    def sample_k_indices_from_pdf(self, pdf, k=2):
        """Sample ``k`` (row, col) cells of a 2-D weight array, proportionally to its values.

        ``pdf`` need not be normalised and is not modified. Sampling is without
        replacement unless fewer than ``k`` cells have non-zero weight.

        Raises:
            ValueError: if the weights do not sum to a finite positive value, or
                contain negative entries.
        """
        pdf = np.asarray(pdf)
        flat_pdf = pdf.astype(np.float64).ravel()  # astype copies, so the caller's array is untouched
        total = flat_pdf.sum()
        if not np.isfinite(total) or total <= 0:
            raise ValueError(f"Cannot sample from weights summing to {total}")
        flat_pdf = flat_pdf / total

        replace = np.count_nonzero(flat_pdf) < k
        flat_indices = np.random.choice(flat_pdf.size, size=k, replace=replace, p=flat_pdf)

        # Convert flat indices to (row, col) tuples
        n_cols = pdf.shape[1]
        return [(int(idx // n_cols), int(idx % n_cols)) for idx in flat_indices]

    def tidy_array(self, arr):
        """Truncate (toward zero) to two decimals, for readable debug printing."""
        arr = (100 * arr).astype(np.int32)
        return arr / 100.0

    def compute_action(self):
        """Score all axis moves and sample the ones to apply this iteration.

        Returns:
            (moves, actions, dists): ``moves`` is a list of ``(obj, displacement)``
            pairs; ``actions`` is the (n_objects, 4) normalised sampling
            distribution and ``dists`` the (n_objects, 4) free-space affinities,
            both with columns ordered +x, +z, -x, -z.

        Raises:
            ValueError: if the move weights do not form a valid distribution.
        """
        if not self.objects:
            return [], np.zeros((0, 4)), np.zeros((0, 4))

        raytracer = Raytracer2D(self.group)
        rows, affinities = [], []
        for obj in self.objects:
            distances = raytracer.compute_free_space_all(obj)
            affinity = np.array([self.free_space_affinity(distances[key]) for key in self._DIRECTION_KEYS])
            g = obj.grad
            # Gradient component along +x, +z, -x, -z, floored at 0.01 so every move
            # with free space keeps a small chance of being picked.
            push = np.maximum(np.array([g[0], g[2], -g[0], -g[2]], dtype=np.float64), 0.01)
            rows.append(push * affinity / obj.get_area())
            affinities.append(affinity)

        # Cubing sharpens the distribution towards the strongest moves.
        actions = np.vstack(rows) ** 3
        dists = np.vstack(affinities)
        total = actions.sum()
        if not np.isfinite(total) or total <= 0:
            raise ValueError(f"Move weights do not form a valid distribution (sum={total})")
        actions = actions / total

        moves = []
        for row, col in self.sample_k_indices_from_pdf(actions):
            obj = self.objects[row]
            axis, direction = self._MOVES[col]
            moves.append((obj, np.abs(obj.grad[axis]) * direction))
        return moves, actions, dists

    def __call__(self, n_iters=10):
        """Run ``n_iters`` optimisation iterations (default 10), translating objects in place."""
        for _ in range(n_iters):
            self.init_gradients()
            self.compute_gradients()
            moves, _, _ = self.compute_action()
            for obj, action in moves:
                step = action * self.lr
                obj.translate(step[0], step[1], step[2])